{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Ares Demo\n",
    "\n",
    "Example on running Ares."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from ares import construct"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Quickstart (simple)\n",
    "\n",
    "If `ares` is installed and the config file is valid, simply run the command\n",
    "\n",
    "```sh\n",
    "$ python -m ares /path/to/config.json\n",
    "```\n",
    "\n",
    "This will run simulation and automatically output the commonly used results."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "! cd .. && python -m ares configs/projected_gradient_descent.json"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Quickstart (intermediate)\n",
    "\n",
    "For more choice on running the simulation, the environment needs to be constructed and run manually. The following steps will guide through the process."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Config File\n",
    "\n",
    "First, load the configuration for the simulation. You can either read the config file from a JSON file or manually define the dictionary in Python. Ensure the path to the defense models and checkpoints are correct."
   ]
  },
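  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To read the configuration from a file instead, a minimal sketch might look like the following. The helper name `load_config` is hypothetical (it is not part of `ares`), and the sketch assumes the file holds the same structure as the dictionary defined in the next cell.\n",
    "\n",
    "```python\n",
    "import json\n",
    "from pathlib import Path\n",
    "\n",
    "\n",
    "def load_config(path):\n",
    "    \"\"\"Read a JSON config file into a Python dictionary (hypothetical helper).\"\"\"\n",
    "    with Path(path).open() as f:\n",
    "        return json.load(f)\n",
    "```\n",
    "\n",
    "For example, `config = load_config(\"../configs/projected_gradient_descent.json\")` would load the file used in the simple quickstart, assuming the notebook runs from a directory one level below the repository root. Note that JSON spells booleans as `true` and `false`, which `json.load` converts to Python's `True` and `False`."
   ]
  },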
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "config = {\n",
    "    \"attacker\": {\n",
    "        \"attacks\": [\n",
    "            {\n",
    "                \"type\": \"evasion\",\n",
    "                \"name\": \"ProjectedGradientDescent\",\n",
    "                \"params\": {\n",
    "                    \"norm\": \"inf\",\n",
    "                    \"eps\": 0.03137254901,\n",
    "                    \"eps_step\": 0.00784313725,\n",
    "                    \"num_random_init\": 0,\n",
    "                    \"max_iter\": 1,\n",
    "                    \"targeted\": False,\n",
    "                    \"batch_size\": 1,\n",
    "                    \"verbose\": False\n",
    "                }\n",
    "            }\n",
    "        ]\n",
    "    },\n",
    "    \"defender\": {\n",
    "        \"probabilities\": [0.3, 0.7],\n",
    "        \"models\": [\n",
    "            {\n",
    "                \"file\": \"../data/models/resnet.py\",\n",
    "                \"name\": \"resnet18_nat\",\n",
    "                \"params\": {\n",
    "                    \"num_classes\": 10\n",
    "                },\n",
    "                \"checkpoint\": \"../data/state_dicts/resnet18_nat.pth\"\n",
    "            },\n",
    "            {\n",
    "                \"file\": \"../data/models/resnet.py\",\n",
    "                \"name\": \"resnet18_adv\",\n",
    "                \"params\": {\n",
    "                    \"num_classes\": 10\n",
    "                },\n",
    "                \"checkpoint\": \"../data/state_dicts/resnet18_adv.pth\"\n",
    "            }\n",
    "        ]\n",
    "    },\n",
    "    \"scenario\": {\n",
    "        \"num_episodes\": 50,\n",
    "        \"max_rounds\": 50,\n",
    "        \"dataset\": {\n",
    "            \"name\": \"cifar10\",\n",
    "            \"input_shape\": [3, 32, 32],\n",
    "            \"num_classes\": 10,\n",
    "            \"clip_values\": [0, 1],\n",
    "            \"random_noise\": True,\n",
    "            \"dataroot\": \"../data/cifar10\"\n",
    "        }\n",
    "    }\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Construct Environment\n",
    "\n",
    "Next, create the simulation environment. A helper function `construct(config)` is provided which will automatically create the attacker, defender, and scenario components based on the config file. This will initialize the `gym` environment which can be run."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# create environment\n",
    "env = construct(config)\n",
    "env"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Run Experiment\n",
    "\n",
    "Once the environment is created, run the simulation for the number of episodes desired. This is always specified in the scenario component. Record the number of rounds and queries until the attacker wins for each trial or the defender wins by a timeout."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "episode_rewards = []\n",
    "episode_queries = []\n",
    "for episode in range(env.scenario.num_episodes):\n",
    "    print(f\"=== Episode {episode + 1} ===\")\n",
    "\n",
    "    # initialize environment\n",
    "    observation = env.reset()\n",
    "    x = observation[\"x\"]\n",
    "    y = observation[\"y\"]\n",
    "    done = False\n",
    "\n",
    "    # run simulation\n",
    "    while not done:\n",
    "        action = {\n",
    "            \"x\": x,\n",
    "            \"y\": y,\n",
    "        }\n",
    "        observation, reward, done, info = env.step(action)\n",
    "        defense = observation[\"defense\"]\n",
    "        attack = observation[\"attack\"]\n",
    "        x = observation[\"x_adv\"]\n",
    "        y_pred = observation[\"y_pred\"]\n",
    "        eps = observation[\"eps\"]\n",
    "        evaded = observation[\"evaded\"]\n",
    "        detected = observation[\"detected\"]\n",
    "        winner = observation[\"winner\"]\n",
    "        step_count = info[\"step_count\"]\n",
    "        queries = info[\"queries\"]\n",
    "\n",
    "        print(f\"Round {step_count:2}: defense = {defense}, attack = {attack}\")\n",
    "        print(f\"\\t [label = {y[0]} | pred = {y_pred[0]}], eps = {eps:.6f}, queries = {queries}\")\n",
    "        if evaded:\n",
    "            print(\"\\t attacker evaded\")\n",
    "        elif detected:\n",
    "            print(\"\\t attacker was detected\")\n",
    "\n",
    "        print(f\"Game end: {winner} wins after {reward} rounds and {queries} queries\")\n",
    "        episode_rewards.append(reward)\n",
    "        episode_queries.append(queries)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Calculate Statistics\n",
    "\n",
    "Finally, calculate the statistics and create the visualizations for the experiment."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "reward_mean = np.mean(episode_rewards)\n",
    "reward_stddev = np.std(episode_rewards)\n",
    "reward_median = np.median(episode_rewards)\n",
    "print(f\"Rounds:  mean = {reward_mean}, stddev = {reward_stddev:.3f}, median = {reward_median}\")\n",
    "\n",
    "queries_mean = np.mean(episode_queries)\n",
    "queries_stddev = np.std(episode_queries)\n",
    "queries_median = np.median(episode_queries)\n",
    "print(f\"Queries: mean = {queries_mean}, stddev = {queries_stddev:.3f}, median = {queries_median}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots(1, 2, figsize=(10, 3))\n",
    "ax[0].hist(episode_rewards)\n",
    "ax[0].set(title='Distribution of Episode Rounds', xlabel='Rounds', ylabel='Count')\n",
    "ax[1].hist(episode_queries)\n",
    "ax[1].set(title='Distribution of Episode Queries', xlabel='Queries', ylabel='Count')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Quickstart (advanced)\n",
    "\n",
    "For full control over the entire simulation, each of the components for the simulation environment can be constructed manually. This is not recommended unless full control is needed."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.9.6 ('venv': venv)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.6"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "f45676dd1ecac82831d6a9330ee4931144303af33e13967da272a8c3f6810f3d"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
